import numpy as np
from math import *
from env import dynamics
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
# Seed both random number generators (NumPy for rollouts, torch for network
# initialisation) so runs are reproducible. The two generators are independent,
# so the order of these calls does not matter.
np.random.seed(2)
torch.manual_seed(2)

class Model(torch.nn.Module):
    """Five-layer fully connected network mapping a 2-feature input to one value.

    The weights are kept in a plain Python list ``self.parameters`` laid out as
    [W1, b1, W2, b2, ..., W5, b5] instead of registered ``nn.Parameter``s.
    This attribute shadows ``nn.Module.parameters()``, and ``dense`` must be
    handed the list explicitly.
    """

    def __init__(self):
        super().__init__()
        # (out_features, in_features) of each linear layer, from input to output.
        layer_shapes = ((512, 2), (256, 512), (256, 256), (128, 256), (1, 128))
        weights_and_biases = []
        for out_features, in_features in layer_shapes:
            # Weights ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)); biases start at zero.
            bound = 1. / sqrt(in_features)
            weight = torch.Tensor(out_features, in_features).uniform_(-bound, bound).requires_grad_()
            bias = torch.Tensor(out_features).zero_().requires_grad_()
            weights_and_biases.append(weight)
            weights_and_biases.append(bias)
        self.parameters = weights_and_biases

    def dense(self, x, parameters):
        """Forward pass using the weight list ``parameters``.

        Applies five linear layers with ReLU after each of the first four and
        no activation on the final layer; returns the final layer's output.
        """
        for layer in range(5):
            x = F.linear(x, parameters[2 * layer], parameters[2 * layer + 1])
            if layer != 4:
                x = F.relu(x)
        return x

def policy(Q_matrix,num_action):
  """Greedy policy with uniform tie-breaking on the 10x13 grid.

  For every cell (x, y), the probability mass is split evenly among the
  actions whose Q-value equals that cell's maximum. Every other action
  gets 0.0.

  Args:
    Q_matrix: indexable as Q_matrix[x][y][a] for x < 10, y < 13, a < num_action.
    num_action: number of actions per cell.

  Returns:
    A (10, 13, num_action) numpy object array of action probabilities.
  """
  # np.object was removed in NumPy 1.24; the builtin `object` is the same dtype.
  distribution=np.zeros((10,13,num_action)).astype(object)
  for x in range(10):
    for y in range(13):
      values=[Q_matrix[x][y][a] for a in range(num_action)]
      max_value=max(values)
      best_actions=[a for a,value in enumerate(values) if value==max_value]
      for a in best_actions:
        distribution[x][y][a]=1.0/len(best_actions)
  return distribution

def Q_matrix_function(gamma,V_matrix,num_action,reward_matrix,cost_matrix):
  """One-step lookahead Q-values on the 10x13 grid.

  Q[x][y][a] = reward_matrix[x, y] - cost_matrix[x, y] + gamma * V_matrix[nx][ny],
  where (nx, ny) is the cell returned by env.dynamics for state (x, y) and
  action a.

  Returns:
    A (10, 13, num_action) numpy object array.
  """
  # np.object was removed in NumPy 1.24; the builtin `object` is the same dtype.
  Q_matrix=np.zeros((10,13,num_action)).astype(object)
  for x in range(10):
    for y in range(13):
      # Per-cell reward minus constraint penalty does not depend on the action.
      immediate=reward_matrix[x,y]-cost_matrix[x,y]
      for a in range(num_action):
        # np.mat was removed in NumPy 2.0; np.asmatrix is the same constructor.
        next_state=dynamics(np.asmatrix([x,y]).T,np.asmatrix([a]).T)
        value=V_matrix[next_state.item(0)][next_state.item(1)]
        Q_matrix[x][y][a]=immediate+gamma*value
  return Q_matrix

def V_matrix_funciton(Q_matrix,num_action,policy):
  """State values of `policy` under the given Q-values on the 10x13 grid.

  V[x][y] = sum over a of policy[x][y][a] * Q_matrix[x][y][a].

  Args:
    Q_matrix: indexable as Q_matrix[x][y][a].
    num_action: number of actions per cell.
    policy: action probabilities indexable as policy[x][y][a] (this parameter
      shadows the module-level `policy` function inside this body).

  Returns:
    A (10, 13) numpy object array.
  """
  # np.object was removed in NumPy 1.24; the builtin `object` is the same dtype.
  V_matrix=np.zeros((10,13)).astype(object)
  for x in range(10):
    for y in range(13):
      V_matrix[x][y]=sum((policy[x][y][a]*Q_matrix[x][y][a] for a in range(num_action)),0.0)
  return V_matrix

def _reward_grid(reward_function, feature):
  """Evaluate reward_function at feature(x, y) for every cell of the 10x13 grid.

  Returns a (10, 13) numpy object array of Python floats.
  """
  grid=np.zeros((10,13)).astype(object)
  # Only the values are needed here, so skip building autograd graphs.
  with torch.no_grad():
    for x in range(10):
      for y in range(13):
        grid[x,y]=reward_function.dense(torch.tensor(feature(x,y)),reward_function.parameters).item()
  return grid


def _greedy_policy_iteration(reward_matrix, cost_matrix, gamma, num_action, sweeps=51):
  """Apply `sweeps` rounds of Q -> greedy policy -> V starting from V = 0; return the last greedy policy."""
  V_matrix=np.zeros((10,13)).astype(object)
  for _ in range(sweeps):
    Q_matrix=Q_matrix_function(gamma,V_matrix,num_action,reward_matrix,cost_matrix)
    greedy=policy(Q_matrix,num_action)
    V_matrix=V_matrix_funciton(Q_matrix,num_action,greedy)
  return greedy


def calculate_policy(reward_function1,reward_function2,omega,gamma,num_action):
  """Plan greedy policies for both agents by repeated Bellman backups.

  Each agent's per-cell reward is its network evaluated at a grid feature,
  minus the shared penalty 100 * omega. Starting from V = 0, the update
  (Q, greedy policy, V) runs 51 times: the original initial sweep plus
  50 loop iterations.

  Prints the agent-1 reward grid.

  Returns:
    (policy1, policy2), each a (10, 13, num_action) object array from `policy`.
  """
  # NOTE(review): agent 1's y feature is 12*y/13 while agent 2's is y/13, and
  # reward_gradient_map feeds raw (x, y) grid coordinates to the same networks.
  # Confirm which input scaling is intended.
  reward1_matrix=_reward_grid(reward_function1,lambda x,y:[1.0*x/10,(13.0-1.0)*y/13])
  reward2_matrix=_reward_grid(reward_function2,lambda x,y:[(9.0-1.0*x)/10,1.0*y/13])
  print('reward1_matrix',reward1_matrix)

  cost_matrix=100*omega

  # The two agents' planning problems share no state, so they are solved independently.
  policy1=_greedy_policy_iteration(reward1_matrix,cost_matrix,gamma,num_action)
  policy2=_greedy_policy_iteration(reward2_matrix,cost_matrix,gamma,num_action)
  return policy1,policy2

def _select_action(distribution, num_action, iteration):
  """Epsilon-greedy draw with epsilon = 0.1 / iteration.

  With probability 1 - epsilon, pick uniformly among actions with positive
  probability in `distribution`. Otherwise pick uniformly among all
  num_action actions.
  """
  greedy_actions=[a for a in range(num_action) if distribution[a]>0.0]
  if np.random.uniform()>(0.1/iteration):
    return greedy_actions[np.random.randint(len(greedy_actions))]
  # Was hard-coded randint(0, 4); use the declared action count instead.
  return np.random.randint(0,num_action)


def trial(initial_state,policy1,policy2,num_action,iteration):
  """Roll out one 35-step joint episode of both agents.

  Args:
    initial_state: 4x1 matrix (x1, y1, x2, y2).
    policy1, policy2: per-cell action distributions indexed [x][y][a].
    num_action: number of actions per agent.
    iteration: online iteration (>= 1); exploration rate is 0.1 / iteration.

  Returns:
    A list of 35 rows [x1, y1, x2, y2, action1, action2], one per step,
    recording the state before the actions are applied.
  """
  trajectory=[]
  state=initial_state
  for _ in range(35):
    action1=_select_action(policy1[state.item(0)][state.item(1)][:],num_action,iteration)
    # np.mat was removed in NumPy 2.0; np.asmatrix is the same constructor.
    next_state1=dynamics(state[0:2],np.asmatrix([action1]).T)

    action2=_select_action(policy2[state.item(2)][state.item(3)][:],num_action,iteration)
    next_state2=dynamics(state[2:4],np.asmatrix([action2]).T)

    trajectory.append([state.item(0),state.item(1),state.item(2),state.item(3),action1,action2])
    state=np.copy(np.vstack((next_state1,next_state2)))
  return trajectory

def constraint_map(number_trials,trajectories):
  """Average per-trial visit counts of both agents on the 10x13 grid.

  Reads the first number_trials episodes (35 rows each) of `trajectories`.
  Columns 0-1 hold agent 1's (x, y) and columns 2-3 hold agent 2's, which
  matches the rows produced by `trial`. Every row adds 1.0 to each agent's
  cell.

  Returns:
    A (10, 13) numpy object array of visit counts divided by number_trials.
  """
  # Named `visits` so the local no longer shadows this function's own name;
  # np.object (removed in NumPy 1.24) is replaced by the builtin `object`.
  visits=np.zeros((10,13)).astype(object)
  for row_index in range(35*number_trials):
    row=trajectories[row_index]
    visits[int(row[0]),int(row[1])]+=1.0
    visits[int(row[2]),int(row[3])]+=1.0
  return visits/number_trials

def reward_gradient_map(reward_function1,reward_function2,number_trials,trajectories):
  """Average reward-parameter gradients over the first number_trials episodes.

  Over the 35 rows of each episode, this sums the gradient of
  reward_function1 evaluated at (column 0, column 1) and of reward_function2
  evaluated at (column 2, column 3). Both sums are then divided by
  number_trials.

  Returns:
    Two lists (one per reward function), each holding one gradient tensor per
    entry of that function's `parameters` list.
  """
  # NOTE(review): these are raw grid coordinates, whereas calculate_policy
  # evaluates the same networks on normalized features. Confirm intent.
  rows=np.asarray(trajectories[:35*number_trials],dtype=float)

  def _averaged_gradient(reward_function,coords):
    # Evaluate all states as one batch: the gradient of the summed reward equals
    # the sum of the per-state gradients. torch.autograd.grad returns fresh
    # tensors and never accumulates into .grad. The old per-state backward()
    # loop never zeroed .grad, so it summed running totals and leaked gradient
    # across calls.
    inputs=torch.tensor(np.ascontiguousarray(coords),dtype=torch.float)
    total=reward_function.dense(inputs,reward_function.parameters).sum()
    grads=torch.autograd.grad(total,reward_function.parameters)
    return [grad/number_trials for grad in grads]

  return (_averaged_gradient(reward_function1,rows[:,0:2]),
          _averaged_gradient(reward_function2,rows[:,2:4]))

def false_positive_negative_rate(omega):
  """Compare the estimated constraint map against the true obstacle layout.

  A cell counts as predicted-constrained when omega[x, y] > 0. Ground truth
  comes from `obstacle_collision`, the same predicate used for the violation
  and success metrics, rather than a second hand-copied list of obstacles.

  Returns:
    (false_positive_rate, false_negative_rate), where
    false_positive_rate = predicted-constrained free cells / free cells and
    false_negative_rate = missed obstacle cells / obstacle cells.
    With the current layout these denominators are 53 and 77.
  """
  true_positive=0.0
  false_positive=0.0
  num_obstacle_cells=0
  for x in range(10):
    for y in range(13):
      is_obstacle=obstacle_collision(x,y)
      if is_obstacle:
        num_obstacle_cells+=1
      if omega[x,y]>0.0:
        if is_obstacle:
          true_positive+=1.0
        else:
          false_positive+=1.0
  num_free_cells=10*13-num_obstacle_cells
  return false_positive/num_free_cells,(num_obstacle_cells-true_positive)/num_obstacle_cells

def obstacle_collision(x,y):
  """Return True if grid cell (x, y) is an obstacle cell, else False.

  The layout is hard-coded:
    * the block of cells with 3 <= x <= 9 and 2 <= y <= 10;
    * two segments at x == 0: 1 <= y <= 5 and 7 <= y <= 11;
    * the single cells (8, 0), (2, 1), (3, 11) and (7, 12).
  Integral floats (e.g. values loaded from trajectory files) compare equal to
  the integer coordinates.
  """
  in_block=3<=x<=9 and 2<=y<=10
  in_x0_segments=x==0 and (1<=y<=5 or 7<=y<=11)
  in_single_cells=(x,y) in ((8,0),(2,1),(3,11),(7,12))
  return bool(in_block or in_x0_segments or in_single_cells)

def constraint_violation_rate(number_trials,trajectories):
  """Fraction of agent-episodes that touch at least one obstacle cell.

  For each of the first number_trials episodes (35 rows each), agent 1 uses
  columns 0-1 and agent 2 uses columns 2-3. An agent-episode counts once if
  any of its 35 cells satisfies `obstacle_collision`.

  Returns:
    violations / (2 * number_trials).
  """
  violation=0.0
  for episode in range(number_trials):
    episode_rows=range(35*episode,35*episode+35)
    for x_col,y_col in ((0,1),(2,3)):
      if any(obstacle_collision(trajectories[r,x_col],trajectories[r,y_col]) for r in episode_rows):
        violation+=1.0
  return violation/(2.0*number_trials)

def success_rate(number_trials,trajectories):
  """Fraction of agent-episodes that reach their goal before hitting an obstacle.

  For each of the first number_trials episodes (35 rows each), the steps are
  scanned in order until the first obstacle cell (`obstacle_collision`) or the
  goal. Agent 1 uses columns 0-1 and succeeds at (9, 12); agent 2 uses
  columns 2-3 and succeeds at (9, 0).

  Returns:
    successes / (2 * number_trials).
  """
  def _reaches_goal(first_row,x_col,y_col,goal_x,goal_y):
    for row_index in range(first_row,first_row+35):
      x=trajectories[row_index,x_col]
      y=trajectories[row_index,y_col]
      if obstacle_collision(x,y):
        return False
      if x==goal_x and y==goal_y:
        return True
    return False

  success=0.0
  for i in range(number_trials):
    if _reaches_goal(35*i,0,1,9,12):
      success=success+1.0
    if _reaches_goal(35*i,2,3,9,0):
      success=success+1.0
  # Normalize by this function's own argument. The old code divided by the
  # module-level num_trials, which was only correct when the two coincided.
  return success/(2.0*number_trials)

# Experiment configuration and expert demonstrations.
num_action=4     # actions per agent
gamma=1.0        # undiscounted planning; also rebinds math.gamma pulled in by `from math import *`
num_trials=50    # learner rollouts per online iteration
# Each expert file must contain 20 episodes x 35 steps x 6 values (enforced by the
# reshape). The layout is presumably the same as the learner rows written by
# `trial`, (x1, y1, x2, y2, action1, action2). TODO confirm.
a1=np.loadtxt("expert_trajectory_file1.txt",dtype=float)
a2=np.loadtxt("expert_trajectory_file2.txt",dtype=float)
expert_trajectories1=a1.reshape(35*20,6)
expert_trajectories2=a2.reshape(35*20,6)
# Joint start state (x1, y1, x2, y2) as a 4x1 column matrix. np.mat was removed
# in NumPy 2.0; np.asmatrix is the same constructor.
initial_state=np.asmatrix([9,0,9,12]).T

def _run_learner(reward_function_a,reward_function_b,omega,iteration,path):
  """Plan with the current estimates, roll out num_trials episodes and log them.

  Policies come from calculate_policy(reward_function_a, reward_function_b,
  omega, ...). Each episode from `trial` is appended to `path`, one value per
  line via np.savetxt, and the file is then reloaded.

  Returns:
    (learner_trajectories, last_episode): the reloaded (35*num_trials, 6) float
    array and the final episode as an ndarray.
  """
  policy_a,policy_b=calculate_policy(reward_function_a,reward_function_b,omega,gamma,num_action)
  # `with` guarantees the log is closed (and flushed) before it is read back.
  with open(path,"w") as trajectory_file:
    for _ in range(num_trials):
      episode=np.copy(trial(initial_state,policy_a,policy_b,num_action,iteration))
      for entry in episode:
        np.savetxt(trajectory_file,entry)
  learner_trajectories=np.loadtxt(path,dtype=float).reshape(35*num_trials,6)
  return learner_trajectories,episode


def _clip_to_unit_interval(omega):
  """Clamp every entry of the (10, 13) array `omega` into [0, 1], in place."""
  for x in range(10):
    for y in range(13):
      omega[x,y]=min(max(omega[x,y],0.0),1.0)


def _evaluate_learner(tag,omega,learner_trajectories):
  """Compute and print one learner's four metrics.

  Returns:
    (false_positive_rate, false_negative_rate, constraint_violation, success).
  """
  false_positive_rate,false_negative_rate=false_positive_negative_rate(omega)
  constraint_violation=constraint_violation_rate(num_trials,learner_trajectories)
  success=success_rate(num_trials,learner_trajectories)
  print('false positive rate'+tag, false_positive_rate)
  print('false negative rate'+tag, false_negative_rate)
  print('constraint violation'+tag, constraint_violation)
  print('success_rate'+tag, success)
  return false_positive_rate,false_negative_rate,constraint_violation,success


def experiment():
  """Run 20 online iterations of two learners that share constraint and reward estimates.

  Each iteration does the following:
    1. Build expert visitation maps from the first i+1 demonstrations.
    2. Let each learner plan and roll out num_trials episodes.
    3. Update each constraint estimate omega from the consensus average of both
       learners' previous estimates plus a step on (learner - expert)
       visitation, then clip to [0, 1]. The step is 0 from iteration 16 on.
    4. Compute and print metrics.
    5. Update the reward networks from the consensus average of their weights
       plus a gradient step.

  Returns:
    Eight lists of 20 values each: false-positive rate, false-negative rate,
    constraint-violation rate and success rate for learner 1, then the same
    four for learner 2.
  """
  history1=([],[],[],[])
  history2=([],[],[],[])
  # np.object (removed in NumPy 1.24) replaced by `object`. omega2 was
  # previously built from omega1 by typo; both start as zeros either way.
  omega1=np.zeros((10,13)).astype(object)
  omega2=np.zeros((10,13)).astype(object)
  # Creation order is kept (11, 12, 21, 22) so torch RNG draws are unchanged.
  reward_function11=Model()
  reward_function12=Model()
  reward_function21=Model()
  reward_function22=Model()

  for i in range(20):
    iteration=i+1
    print('online iteration', iteration)
    expert_constraint_map1=1000*constraint_map(iteration,expert_trajectories1)
    expert_constraint_map2=1000*constraint_map(iteration,expert_trajectories2)

    learner_trajectories1,trajectory1=_run_learner(reward_function11,reward_function12,omega1,iteration,"learner_trajectory_file1.txt")
    learner_trajectories2,_=_run_learner(reward_function21,reward_function22,omega2,iteration,"learner_trajectory_file2.txt")
    learner_constraint_map1=constraint_map(num_trials,learner_trajectories1)
    learner_constraint_map2=constraint_map(num_trials,learner_trajectories2)

    step_size=1 if i<15 else 0
    # Both learners start from the same average of the *previous* estimates.
    # The old code computed omega2 from the already-updated omega1, unlike the
    # simultaneous reward-parameter consensus below.
    omega_average=0.5*(omega1+omega2)
    omega1=omega_average+0.1*step_size*(learner_constraint_map1-expert_constraint_map1)
    omega2=omega_average+0.1*step_size*(learner_constraint_map2-expert_constraint_map2)
    _clip_to_unit_interval(omega1)
    _clip_to_unit_interval(omega2)

    expert_gradients11,expert_gradients12=reward_gradient_map(reward_function11,reward_function12,iteration,expert_trajectories1)
    learner_gradients11,learner_gradients12=reward_gradient_map(reward_function11,reward_function12,num_trials,learner_trajectories1)
    for history,value in zip(history1,_evaluate_learner('1',omega1,learner_trajectories1)):
      history.append(value)

    expert_gradients21,expert_gradients22=reward_gradient_map(reward_function21,reward_function22,iteration,expert_trajectories2)
    learner_gradients21,learner_gradients22=reward_gradient_map(reward_function21,reward_function22,num_trials,learner_trajectories2)
    for history,value in zip(history2,_evaluate_learner('2',omega2,learner_trajectories2)):
      history.append(value)

    # Reward consensus: average both learners' weights, then step each on its
    # own (learner - expert) gradient difference.
    # NOTE(review): networks 11/21 step with '+' and 12/22 with '-' on the same
    # (learner - expert) difference. Confirm this sign asymmetry is intended.
    for number in range(10):
      a=0.5*(reward_function11.parameters[number].data+reward_function21.parameters[number].data)
      b=0.5*(reward_function12.parameters[number].data+reward_function22.parameters[number].data)
      reward_function11.parameters[number].data=a+0.005*(0.00001*(learner_gradients11[number]-expert_gradients11[number]))
      reward_function12.parameters[number].data=b-0.005*(0.00001*(learner_gradients12[number]-expert_gradients12[number]))
      reward_function21.parameters[number].data=a+0.005*(0.00001*(learner_gradients21[number]-expert_gradients21[number]))
      reward_function22.parameters[number].data=b-0.005*(0.00001*(learner_gradients22[number]-expert_gradients22[number]))

    print(trajectory1)
  return history1+history2

# Metric histories across experiments: each list receives one 20-entry list per run.
all_false_positive_list1=[]
all_false_negative_list1=[]
all_constraint_violation_list1=[]
all_success_list1=[]

all_false_positive_list2=[]
all_false_negative_list2=[]
all_constraint_violation_list2=[]
all_success_list2=[]

num_experiment=2

# Order must match the 8-tuple returned by experiment().
_run_histories=(
  all_false_positive_list1,all_false_negative_list1,all_constraint_violation_list1,all_success_list1,
  all_false_positive_list2,all_false_negative_list2,all_constraint_violation_list2,all_success_list2,
)

start_time=time.time()
for number in range(num_experiment):
  for history,run_metrics in zip(_run_histories,experiment()):
    history.append(run_metrics)
end_time=time.time()
# Elapsed wall time divided by 2*num_experiment (presumably per learner per run; TODO confirm).
print('time cost for one experiment',(end_time-start_time)/(2*num_experiment))

# Across-experiment mean and population standard deviation of every metric.
# Each entry is a one-element list, so np.savetxt later writes one value per line.
false_positive_mean_list1=[]
false_positive_sd_list1=[]
false_negative_mean_list1=[]
false_negative_sd_list1=[]
constraint_violation_mean_list1=[]
constraint_violation_sd_list1=[]
success_mean_list1=[]
success_sd_list1=[]

false_positive_mean_list2=[]
false_positive_sd_list2=[]
false_negative_mean_list2=[]
false_negative_sd_list2=[]
constraint_violation_mean_list2=[]
constraint_violation_sd_list2=[]
success_mean_list2=[]
success_sd_list2=[]

# (per-run histories, mean output, sd output) for each metric.
_summary_specs=(
  (all_false_positive_list1,false_positive_mean_list1,false_positive_sd_list1),
  (all_false_negative_list1,false_negative_mean_list1,false_negative_sd_list1),
  (all_constraint_violation_list1,constraint_violation_mean_list1,constraint_violation_sd_list1),
  (all_success_list1,success_mean_list1,success_sd_list1),
  (all_false_positive_list2,false_positive_mean_list2,false_positive_sd_list2),
  (all_false_negative_list2,false_negative_mean_list2,false_negative_sd_list2),
  (all_constraint_violation_list2,constraint_violation_mean_list2,constraint_violation_sd_list2),
  (all_success_list2,success_mean_list2,success_sd_list2),
)

for i in range(20):
  for runs,mean_list,sd_list in _summary_specs:
    samples=[runs[j][i] for j in range(num_experiment)]
    mean_list.append([sum(samples)/len(samples)])
    sd_list.append([sqrt(np.var(samples))])

#print(false_positive_mean_list)

# Write each summary series to its own text file, one value per line, in the
# order: learner 1 means, learner 1 sds, learner 2 means, learner 2 sds.
_summary_outputs=(
  ("learner1_false_positive_mean_file.txt",false_positive_mean_list1),
  ("learner1_false_negative_mean_file.txt",false_negative_mean_list1),
  ("learner1_constraint_violation_mean_file.txt",constraint_violation_mean_list1),
  ("learner1_success_mean_file.txt",success_mean_list1),
  ("learner1_false_positive_sd_file.txt",false_positive_sd_list1),
  ("learner1_false_negative_sd_file.txt",false_negative_sd_list1),
  ("learner1_constraint_violation_sd_file.txt",constraint_violation_sd_list1),
  ("learner1_success_sd_file.txt",success_sd_list1),
  ("learner2_false_positive_mean_file.txt",false_positive_mean_list2),
  ("learner2_false_negative_mean_file.txt",false_negative_mean_list2),
  ("learner2_constraint_violation_mean_file.txt",constraint_violation_mean_list2),
  ("learner2_success_mean_file.txt",success_mean_list2),
  ("learner2_false_positive_sd_file.txt",false_positive_sd_list2),
  ("learner2_false_negative_sd_file.txt",false_negative_sd_list2),
  ("learner2_constraint_violation_sd_file.txt",constraint_violation_sd_list2),
  ("learner2_success_sd_file.txt",success_sd_list2),
)

for _path,_series in _summary_outputs:
  with open(_path,"w") as _handle:
    for entry in _series:
      np.savetxt(_handle,entry)

















